package handler

import (
	"github.com/DiracLee/dires-go/app/cmdline"
	"github.com/DiracLee/dires-go/app/database"
	"github.com/DiracLee/dires-go/app/payload"
	"github.com/DiracLee/dires-go/utils"
	"strconv"
	"strings"
	"time"
)

// Write policies selected by SET's NX / XX options.
const (
	upsertPolicy = iota // update or insert (default): plain SET
	insertPolicy        // insert only if the key is absent: SET ... NX
	updatePolicy        // update only if the key is present: SET ... XX
)

// unlimitedTTL marks a key that has no expiration time.
const unlimitedTTL int64 = 0

// Option tokens accepted by SET (handleSet upper-cases arguments before
// comparing, so they match case-insensitively).
const (
	// ArgNX selects the insert policy: only set the key if it does not exist.
	ArgNX = "NX"
	// ArgXX selects the update policy: only set the key if it already exists.
	ArgXX = "XX"

	// ArgEX sets the TTL in seconds: EX seconds.
	ArgEX = "EX"
	// ArgPX sets the TTL in milliseconds: PX milliseconds.
	ArgPX = "PX"
)

// init registers every string command implemented in this file.
//
// Each RegisterCommand call passes, in order: the command name, the handler,
// a key-preparation function (readFirstKey / writeFirstKey / prepareMSet /
// prepareMGet), an undo function (rollbackFirstKey / undoMSet, nil for the
// read-only commands GET, MGET, STRLEN, GETRANGE) and the arity.
//
// NOTE(review): the arity appears to count the command name itself (SETNX
// key value -> 3, while handleSetNX reads args[0] and args[1]), and a
// negative arity presumably means "at least |arity| arguments" for the
// variadic commands (SET, MSET, MSETNX, MGET) — confirm against
// RegisterCommand.
func init() {
	RegisterCommand(cmdline.CmdSet, handleSet, writeFirstKey, rollbackFirstKey, -3)
	RegisterCommand(cmdline.CmdSetNX, handleSetNX, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdSetEX, handleSetEX, writeFirstKey, rollbackFirstKey, 4)
	RegisterCommand(cmdline.CmdPSetEX, handlePSetEX, writeFirstKey, rollbackFirstKey, 4)
	// multi-key writes share prepareMSet / undoMSet
	RegisterCommand(cmdline.CmdMSet, handleMSet, prepareMSet, undoMSet, -3)
	RegisterCommand(cmdline.CmdMSetNX, handleMSetNX, prepareMSet, undoMSet, -3)
	// read-only commands register no undo function
	RegisterCommand(cmdline.CmdGet, handleGet, readFirstKey, nil, 2)
	RegisterCommand(cmdline.CmdMGet, handleMGet, prepareMGet, nil, -2)
	RegisterCommand(cmdline.CmdGetSet, handleGetSet, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdIncr, handleIncr, writeFirstKey, rollbackFirstKey, 2)
	RegisterCommand(cmdline.CmdIncrBy, handleIncrBy, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdIncrByFloat, handleIncrByFloat, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdDecr, handleDecr, writeFirstKey, rollbackFirstKey, 2)
	RegisterCommand(cmdline.CmdDecrBy, handleDecrBy, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdStrLen, handleStrLen, readFirstKey, nil, 2)
	RegisterCommand(cmdline.CmdAppend, handleAppend, writeFirstKey, rollbackFirstKey, 3)
	RegisterCommand(cmdline.CmdSetRange, handleSetRange, writeFirstKey, rollbackFirstKey, 4)
	RegisterCommand(cmdline.CmdGetRange, handleGetRange, readFirstKey, nil, 4)
}

// handleSet implements SET key value [NX|XX] [EX seconds|PX milliseconds].
//
// NX and XX are mutually exclusive, and at most one of EX/PX may be given.
// A non-positive TTL, or one too large to represent as a time.Duration, is
// rejected. The reply is OK when the value was written and a null bulk string
// when the NX/XX condition prevented the write. Only writes that happened are
// appended to the AOF. With a TTL, the log gets a plain SET followed by an
// expire command built from the computed deadline. Without a TTL, the key is
// persisted and the original command is logged.
func handleSet(db database.DB, args cmdline.CmdLine) payload.Payload {
	// maxTTLMillis is the largest TTL (ms) whose time.Duration fits in int64
	// nanoseconds; anything larger would wrap to a negative duration and give
	// the key an expiry in the past.
	const maxTTLMillis = (1<<63 - 1) / int64(time.Millisecond)

	key := string(args[0])
	value := args[1]
	policy := upsertPolicy
	ttl := unlimitedTTL

	// parse options
	for i := 2; i < len(args); i++ {
		opt := strings.ToUpper(string(args[i]))
		switch opt {
		case ArgNX:
			if policy == updatePolicy {
				return &payload.SyntaxErrPayload{}
			}
			policy = insertPolicy
		case ArgXX:
			if policy == insertPolicy {
				return &payload.SyntaxErrPayload{}
			}
			policy = updatePolicy
		case ArgEX, ArgPX:
			// a second EX/PX, or EX/PX as the last argument, is a syntax error
			if ttl != unlimitedTTL || i+1 >= len(args) {
				return &payload.SyntaxErrPayload{}
			}
			ttlArg, err := strconv.ParseInt(string(args[i+1]), 10, 64)
			if err != nil {
				return &payload.SyntaxErrPayload{}
			}
			if ttlArg <= 0 {
				return payload.NewErrPayload("ERR invalid expire time in set")
			}
			if opt == ArgEX {
				// bound-check before multiplying so seconds->ms cannot overflow
				if ttlArg > maxTTLMillis/1000 {
					return payload.NewErrPayload("ERR invalid expire time in set")
				}
				ttlArg *= 1000
			} else if ttlArg > maxTTLMillis {
				return payload.NewErrPayload("ERR invalid expire time in set")
			}
			ttl = ttlArg
			i++ // skip the TTL value
		default:
			return &payload.SyntaxErrPayload{}
		}
	}

	entity := &database.DataEntity{
		Data: value,
	}

	var result int
	switch policy {
	case upsertPolicy:
		db.PutOrSet(key, entity)
		result = 1
	case insertPolicy:
		result = db.PutIfNotExists(key, entity)
	case updatePolicy:
		result = db.SetIfExists(key, entity)
	}
	if result <= 0 {
		// NX/XX condition not met: nothing written, nothing logged
		return &payload.NullBulkPayload{}
	}

	if ttl != unlimitedTTL {
		expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
		db.SetKeyTTL(key, expireTime)
		// log the write without its options, then the deadline separately
		db.AddAOF(cmdline.CmdLine{
			[]byte(cmdline.CmdSet),
			args[0],
			args[1],
		})
		db.AddAOF(ExpirePayload(key, expireTime).Args)
	} else {
		db.Persist(key) // set unlimited ttl
		db.AddAOF(NamedCommand(cmdline.CmdSet, args...))
	}
	return &payload.OkPayload{}
}

// handleSetNX implements SETNX key value: store value only if key does not
// already exist. The reply is the integer returned by PutIfNotExists
// (presumably 1 when the key was set, 0 otherwise). Like SET ... NX and
// MSETNX, only an effective write is appended to the AOF.
func handleSetNX(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])
	value := args[1]
	entity := &database.DataEntity{
		Data: value,
	}
	result := db.PutIfNotExists(key, entity)
	if result > 0 {
		db.AddAOF(NamedCommand(cmdline.CmdSetNX, args...))
	}
	return payload.NewIntPayload(int64(result))
}

// handleSetEX implements SETEX key seconds value: unconditionally store value
// and expire it after the given number of seconds.
//
// The TTL must be a positive integer small enough that the resulting
// time.Duration does not overflow; otherwise an error is returned. The AOF
// gets the original command followed by an expire command built from the
// computed deadline.
func handleSetEX(db database.DB, args cmdline.CmdLine) payload.Payload {
	// maxTTLSeconds keeps seconds*1000 (ms), and the Duration built from it
	// in nanoseconds, within int64.
	const maxTTLSeconds = (1<<63 - 1) / int64(time.Second)

	key := string(args[0])
	value := args[2]

	ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
	if err != nil {
		return &payload.SyntaxErrPayload{}
	}
	if ttlArg <= 0 || ttlArg > maxTTLSeconds {
		return payload.NewErrPayload("ERR invalid expire time in setex")
	}
	ttl := ttlArg * 1000 // milliseconds

	entity := &database.DataEntity{
		Data: value,
	}

	db.PutOrSet(key, entity)
	expireTime := time.Now().Add(time.Duration(ttl) * time.Millisecond)
	db.SetKeyTTL(key, expireTime)
	db.AddAOF(NamedCommand(cmdline.CmdSetEX, args...))
	db.AddAOF(ExpirePayload(key, expireTime).Args)
	return &payload.OkPayload{}
}

// handlePSetEX implements PSETEX key milliseconds value: unconditionally store
// value and expire it after the given number of milliseconds.
//
// The TTL must be a positive integer small enough that the resulting
// time.Duration does not overflow; otherwise an error is returned. The AOF
// gets the original command followed by an expire command built from the
// computed deadline.
func handlePSetEX(db database.DB, args cmdline.CmdLine) payload.Payload {
	// maxTTLMillis is the largest TTL (ms) whose time.Duration fits in int64
	// nanoseconds; larger values would wrap to a negative duration.
	const maxTTLMillis = (1<<63 - 1) / int64(time.Millisecond)

	key := string(args[0])
	value := args[2]

	ttlArg, err := strconv.ParseInt(string(args[1]), 10, 64)
	if err != nil {
		return &payload.SyntaxErrPayload{}
	}
	if ttlArg <= 0 || ttlArg > maxTTLMillis {
		return payload.NewErrPayload("ERR invalid expire time in psetex")
	}

	entity := &database.DataEntity{
		Data: value,
	}

	db.PutOrSet(key, entity)
	expireTime := time.Now().Add(time.Duration(ttlArg) * time.Millisecond)
	db.SetKeyTTL(key, expireTime)
	db.AddAOF(NamedCommand(cmdline.CmdPSetEX, args...))
	db.AddAOF(ExpirePayload(key, expireTime).Args)

	return &payload.OkPayload{}
}

// handleMSet implements MSET key value [key value ...]: every pair is stored,
// overwriting existing values, in argument order. An odd number of arguments
// is a syntax error; otherwise the reply is always OK.
func handleMSet(db database.DB, args cmdline.CmdLine) payload.Payload {
	if len(args)%2 == 1 {
		return payload.NewSyntaxErrPayload()
	}
	for pos := 0; pos < len(args); pos += 2 {
		db.PutOrSet(string(args[pos]), &database.DataEntity{Data: args[pos+1]})
	}
	db.AddAOF(NamedCommand(cmdline.CmdMSet, args...))
	return &payload.OkPayload{}
}

// handleMSetNX implements MSETNX key value [key value ...]: the pairs are
// stored only if none of the keys exist yet. The reply is 1 when every pair
// was written, or 0 (with nothing written or logged) as soon as one key is
// found to exist. An odd number of arguments is a syntax error.
func handleMSetNX(db database.DB, args cmdline.CmdLine) payload.Payload {
	if len(args)%2 == 1 {
		return payload.NewSyntaxErrPayload()
	}

	// first pass: bail out on the first key that is already present
	for pos := 0; pos < len(args); pos += 2 {
		if _, exists := db.Get(string(args[pos])); exists {
			return payload.NewIntPayload(0)
		}
	}

	// second pass: no key exists, so write them all
	for pos := 0; pos < len(args); pos += 2 {
		db.PutOrSet(string(args[pos]), &database.DataEntity{Data: args[pos+1]})
	}
	db.AddAOF(NamedCommand(cmdline.CmdMSetNX, args...))
	return payload.NewIntPayload(1)
}

// undoMSet builds the undo command lines for MSET/MSETNX. It collects the
// written keys via prepareMSet and hands them to rollbackGivenKeys.
func undoMSet(db database.DB, args cmdline.CmdLine) []cmdline.CmdLine {
	keys, _ := prepareMSet(args) // only the first result is needed here
	return rollbackGivenKeys(db, keys...)
}

// handleGet implements GET key. The reply is one of:
//   - the lookup error from GetAsString (e.g. the key holds a non-string value);
//   - a null bulk string when the key is missing;
//   - the stored bytes otherwise.
func handleGet(db database.DB, args cmdline.CmdLine) payload.Payload {
	data, errReply := db.GetAsString(string(args[0]))
	switch {
	case errReply != nil:
		return errReply
	case data == nil:
		return &payload.NullBulkPayload{}
	default:
		return payload.NewBulkPayload(data)
	}
}

// handleMGet implements MGET key [key ...]: the reply has one entry per key,
// in argument order. Missing keys and keys holding a non-string value
// (WrongTypeErrPayload) become nil entries. Any other lookup error aborts the
// command and is returned as-is.
func handleMGet(db database.DB, args cmdline.CmdLine) payload.Payload {
	values := make([][]byte, len(args))
	for idx, rawKey := range args {
		data, errReply := db.GetAsString(string(rawKey))
		if errReply == nil {
			values[idx] = data // nil when the key is missing
			continue
		}
		if _, wrongType := errReply.(*payload.WrongTypeErrPayload); !wrongType {
			return errReply
		}
		// wrong-type keys leave values[idx] as nil
	}
	return payload.NewMultiBulkPayload(values)
}

// handleGetSet implements GETSET key value: replace the value at key, clear
// any TTL it had, and reply with the previous value (or a null bulk string if
// there was none). A lookup error is returned before anything is written.
func handleGetSet(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])

	previous, errReply := db.GetAsString(key)
	if errReply != nil {
		return errReply
	}

	db.PutOrSet(key, &database.DataEntity{Data: args[1]})
	db.Persist(key) // the new value starts without an expiration
	db.AddAOF(NamedCommand(cmdline.CmdGetSet, args...))

	if previous != nil {
		return payload.NewBulkPayload(previous)
	}
	return new(payload.NullBulkPayload)
}

// handleIncr implements INCR key: add 1 to the integer stored at key, treating
// a missing key as 0. The reply is the new value. Errors are returned when the
// lookup fails, when the stored value is not a base-10 int64, or when the
// increment would overflow int64.
func handleIncr(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])

	bytes, err := db.GetAsString(key)
	if err != nil {
		return err
	}
	var val int64 // a missing key behaves like "0"
	if bytes != nil {
		parsed, parseErr := strconv.ParseInt(string(bytes), 10, 64)
		if parseErr != nil {
			return payload.NewErrPayload("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	next := val + 1
	if next < val { // Go wraps signed overflow: MaxInt64+1 becomes MinInt64
		return payload.NewErrPayload("ERR increment or decrement would overflow")
	}
	db.PutOrSet(key, &database.DataEntity{
		Data: []byte(strconv.FormatInt(next, 10)),
	})
	db.AddAOF(NamedCommand(cmdline.CmdIncr, args...))
	return payload.NewIntPayload(next)
}

// handleIncrBy implements INCRBY key delta: add delta to the integer stored at
// key, treating a missing key as 0. The result is always stored in canonical
// base-10 form, so "+5" or "005" on a missing key is stored as "5". The reply
// is the new value. Errors are returned for a non-integer delta or stored
// value, a failed lookup, or an int64 overflow.
func handleIncrBy(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])
	rawDelta := string(args[1])
	delta, err := strconv.ParseInt(rawDelta, 10, 64)
	if err != nil {
		return payload.NewErrPayload("ERR value is not an integer or out of range")
	}

	bytes, errPayload := db.GetAsString(key)
	if errPayload != nil {
		return errPayload
	}
	var val int64 // a missing key behaves like "0"
	if bytes != nil {
		val, err = strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return payload.NewErrPayload("ERR value is not an integer or out of range")
		}
	}

	sum := val + delta
	// Go wraps signed overflow, so a sum that moved against delta's sign overflowed.
	if (delta > 0 && sum < val) || (delta < 0 && sum > val) {
		return payload.NewErrPayload("ERR increment or decrement would overflow")
	}
	db.PutOrSet(key, &database.DataEntity{
		Data: []byte(strconv.FormatInt(sum, 10)),
	})
	db.AddAOF(NamedCommand(cmdline.CmdIncrBy, args...))
	return payload.NewIntPayload(sum)
}

// handleIncrByFloat implements INCRBYFLOAT key delta using utils.NewFromString
// numbers. The new value is stored and replied as a bulk string produced by
// the number's String() method. A missing key stores delta in that same form
// rather than the raw argument, so both paths produce the same representation.
// Errors are returned when delta or the stored value does not parse, or when
// the lookup fails.
func handleIncrByFloat(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])
	rawDelta := string(args[1])
	delta, err := utils.NewFromString(rawDelta)
	if err != nil {
		return payload.NewErrPayload("ERR value is not a valid float")
	}

	bytes, errPayload := db.GetAsString(key)
	if errPayload != nil {
		return errPayload
	}

	var resultBytes []byte
	if bytes != nil {
		val, err := utils.NewFromString(string(bytes))
		if err != nil {
			return payload.NewErrPayload("ERR value is not a valid float")
		}
		resultBytes = []byte(val.Add(delta).String())
	} else {
		// missing key: behaves like 0 + delta
		resultBytes = []byte(delta.String())
	}

	db.PutOrSet(key, &database.DataEntity{
		Data: resultBytes,
	})
	db.AddAOF(NamedCommand(cmdline.CmdIncrByFloat, args...))
	return payload.NewBulkPayload(resultBytes)
}

// handleDecr implements DECR key: subtract 1 from the integer stored at key,
// treating a missing key as 0. The reply is the new value. Errors are returned
// when the lookup fails, when the stored value is not a base-10 int64, or when
// the decrement would overflow int64.
func handleDecr(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])

	bytes, errPayload := db.GetAsString(key)
	if errPayload != nil {
		return errPayload
	}
	var val int64 // a missing key behaves like "0"
	if bytes != nil {
		parsed, err := strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return payload.NewErrPayload("ERR value is not an integer or out of range")
		}
		val = parsed
	}

	next := val - 1
	if next > val { // Go wraps signed overflow: MinInt64-1 becomes MaxInt64
		return payload.NewErrPayload("ERR increment or decrement would overflow")
	}
	db.PutOrSet(key, &database.DataEntity{
		Data: []byte(strconv.FormatInt(next, 10)),
	})
	db.AddAOF(NamedCommand(cmdline.CmdDecr, args...))
	return payload.NewIntPayload(next)
}

// handleDecrBy implements DECRBY key delta: subtract delta from the integer
// stored at key, treating a missing key as 0. The reply is the new value.
// Errors are returned for a non-integer delta or stored value, a failed
// lookup, or an int64 overflow. That includes DECRBY of MinInt64 on a missing
// key, whose negation is not representable.
func handleDecrBy(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])
	rawDelta := string(args[1])
	delta, err := strconv.ParseInt(rawDelta, 10, 64)
	if err != nil {
		return payload.NewErrPayload("ERR value is not an integer or out of range")
	}

	bytes, errPayload := db.GetAsString(key)
	if errPayload != nil {
		return errPayload
	}
	var val int64 // a missing key behaves like "0"
	if bytes != nil {
		val, err = strconv.ParseInt(string(bytes), 10, 64)
		if err != nil {
			return payload.NewErrPayload("ERR value is not an integer or out of range")
		}
	}

	diff := val - delta
	// Go wraps signed overflow, so a difference that moved in delta's direction overflowed.
	if (delta > 0 && diff > val) || (delta < 0 && diff < val) {
		return payload.NewErrPayload("ERR increment or decrement would overflow")
	}
	db.PutOrSet(key, &database.DataEntity{
		Data: []byte(strconv.FormatInt(diff, 10)),
	})
	db.AddAOF(NamedCommand(cmdline.CmdDecrBy, args...))
	return payload.NewIntPayload(diff)
}

// handleStrLen implements STRLEN key: reply with the byte length of the string
// stored at key, 0 when the key is missing, or the lookup error.
func handleStrLen(db database.DB, args cmdline.CmdLine) payload.Payload {
	data, errReply := db.GetAsString(string(args[0]))
	if errReply != nil {
		return errReply
	}
	// a missing key yields a nil slice, whose length is 0
	return payload.NewIntPayload(int64(len(data)))
}

// handleAppend implements APPEND key value: append value to the string at key
// (a missing key starts out empty) and reply with the resulting byte length.
//
// NOTE(review): append may reuse spare capacity of the slice returned by
// GetAsString; confirm that slice is not shared with data that must stay
// untouched.
func handleAppend(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])

	current, errReply := db.GetAsString(key)
	if errReply != nil {
		return errReply
	}

	combined := append(current, args[1]...)
	db.PutOrSet(key, &database.DataEntity{Data: combined})
	db.AddAOF(NamedCommand(cmdline.CmdAppend, args...))
	return payload.NewIntPayload(int64(len(combined)))
}

// handleSetRange implements SETRANGE key offset value: overwrite the string at
// key starting at byte offset with value, zero-padding when offset lies past
// the current end. A missing key counts as an empty string. The reply is the
// string's length after the write.
//
// Following Redis:
//   - a negative offset is an error;
//   - a result longer than 512MB is an error;
//   - an empty value writes nothing and just reports the current length.
//
// The result is built in a fresh buffer instead of being written in place.
// The slice returned by GetAsString may be shared with the stored value and
// with replies already produced from it, e.g. a GET queued earlier in the
// same transaction, so mutating it would change data other code still holds.
func handleSetRange(db database.DB, args cmdline.CmdLine) payload.Payload {
	// maxStringLen mirrors Redis' default 512MB string limit; it also bounds
	// the padding allocation so a huge offset cannot crash the server.
	const maxStringLen = 512 * 1024 * 1024

	key := string(args[0])
	offset, errNative := strconv.ParseInt(string(args[1]), 10, 64)
	if errNative != nil {
		return payload.NewErrPayload("ERR value is not an integer or out of range")
	}
	if offset < 0 {
		return payload.NewErrPayload("ERR offset is out of range")
	}
	value := args[2]

	old, err := db.GetAsString(key)
	if err != nil {
		return err
	}
	if len(value) == 0 {
		// nothing to write: neither create nor pad the key
		return payload.NewIntPayload(int64(len(old)))
	}
	// written as a subtraction so offset+len(value) cannot overflow
	if offset > maxStringLen-int64(len(value)) {
		return payload.NewErrPayload("ERR string exceeds maximum allowed size (512MB)")
	}

	size := int64(len(old))
	if end := offset + int64(len(value)); end > size {
		size = end
	}
	buf := make([]byte, size) // bytes between the old end and offset stay zero
	copy(buf, old)
	copy(buf[offset:], value)

	db.PutOrSet(key, &database.DataEntity{
		Data: buf,
	})
	db.AddAOF(NamedCommand(cmdline.CmdSetRange, args...))
	return payload.NewIntPayload(size)
}

// handleGetRange implements GETRANGE key start end: reply with the bytes of
// the string at key between the inclusive offsets start and end.
//
// Negative offsets count from the end (-1 is the last byte). Offsets outside
// the string are clamped to it, as Redis does: GETRANGE "hello" -100 -1
// returns "hello". An empty range, an empty string or a missing key yields an
// empty bulk string. Non-integer offsets are an error.
func handleGetRange(db database.DB, args cmdline.CmdLine) payload.Payload {
	key := string(args[0])
	startIdx, errNative := strconv.ParseInt(string(args[1]), 10, 64)
	if errNative != nil {
		return payload.NewErrPayload("ERR value is not an integer or out of range")
	}
	endIdx, errNative := strconv.ParseInt(string(args[2]), 10, 64)
	if errNative != nil {
		return payload.NewErrPayload("ERR value is not an integer or out of range")
	}

	bytes, err := db.GetAsString(key)
	if err != nil {
		return err
	}

	strLen := int64(len(bytes)) // 0 for a missing key
	// two negative offsets with start after end can never select anything
	if startIdx < 0 && endIdx < 0 && startIdx > endIdx {
		return payload.NewBulkPayload([]byte{})
	}
	if startIdx < 0 {
		startIdx += strLen
	}
	if endIdx < 0 {
		endIdx += strLen
	}
	if startIdx < 0 {
		startIdx = 0
	}
	if endIdx < 0 {
		endIdx = 0
	}
	if endIdx >= strLen {
		endIdx = strLen - 1
	}
	if strLen == 0 || startIdx > endIdx {
		return payload.NewBulkPayload([]byte{})
	}

	return payload.NewBulkPayload(bytes[startIdx : endIdx+1])
}
